{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Subset of newsgroup names passed as `categories` when the train and test splits are fetched.\n",
    "categories = ['alt.atheism', 'soc.religion.christian', 'comp.graphics', 'sci.med']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.datasets import fetch_20newsgroups\n",
    "\n",
    "# Training split, limited to the selected categories; the shuffle is seeded via random_state.\n",
    "twenty_train = fetch_20newsgroups(\n",
    "    subset='train',\n",
    "    categories=categories,\n",
    "    shuffle=True,\n",
    "    random_state=42,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['alt.atheism', 'comp.graphics', 'sci.med', 'soc.religion.christian']"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Category names of the loaded split; later cells index this list with the integer targets.\n",
    "twenty_train.target_names"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "2257"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Number of documents in the training split.\n",
    "len(twenty_train.data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "2257"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Number of entries in `filenames`; compare with len(twenty_train.data).\n",
    "len(twenty_train.filenames)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "From: sd345@city.ac.uk (Michael Collier)\n",
      "Subject: Converting images to HP LaserJet III?\n",
      "Nntp-Posting-Host: hampton\n"
     ]
    }
   ],
   "source": [
    "# Show only the first three lines of the first training document.\n",
    "print(*twenty_train.data[0].split(\"\\n\")[:3], sep=\"\\n\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "comp.graphics\n"
     ]
    }
   ],
   "source": [
    "# Map the first document's integer target back to its category name.\n",
    "print(twenty_train.target_names[twenty_train.target[0]])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([1, 1, 3, 3, 3, 3, 3, 2, 2, 2])"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Integer targets of the first ten training documents.\n",
    "twenty_train.target[:10]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "comp.graphics\n",
      "comp.graphics\n",
      "soc.religion.christian\n",
      "soc.religion.christian\n",
      "soc.religion.christian\n",
      "soc.religion.christian\n",
      "soc.religion.christian\n",
      "sci.med\n",
      "sci.med\n",
      "sci.med\n"
     ]
    }
   ],
   "source": [
    "# Category names of the first ten training documents, one per line.\n",
    "print(\"\\n\".join(twenty_train.target_names[t] for t in twenty_train.target[:10]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(2257, 35788)"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from sklearn.feature_extraction.text import CountVectorizer\n",
    "# Fit the vectorizer on the training documents and build their count matrix in one step.\n",
    "count_vect = CountVectorizer()\n",
    "X_train_counts = count_vect.fit_transform(twenty_train.data)\n",
    "# Shape of the resulting count matrix.\n",
    "X_train_counts.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "4690"
      ]
     },
     "execution_count": 25,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Look up the fitted vocabulary entry for the token 'algorithm'.\n",
    "count_vect.vocabulary_.get('algorithm')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(2257, 35788)"
      ]
     },
     "execution_count": 26,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from sklearn.feature_extraction.text import TfidfTransformer\n",
    "\n",
    "# Transformer configured with use_idf=False; fit and transform the counts in a single call.\n",
    "tf_transformer = TfidfTransformer(use_idf=False)\n",
    "X_train_tf = tf_transformer.fit_transform(X_train_counts)\n",
    "X_train_tf.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(2257, 35788)"
      ]
     },
     "execution_count": 27,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Transformer with default settings, fitted on the training counts; later cells reuse it via transform().\n",
    "tfidf_transformer = TfidfTransformer()\n",
    "X_train_tfidf = tfidf_transformer.fit_transform(X_train_counts)\n",
    "X_train_tfidf.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.naive_bayes import MultinomialNB\n",
    "# Fit a multinomial Naive Bayes classifier on the TF-IDF features of the training set.\n",
    "clf = MultinomialNB().fit(X_train_tfidf, twenty_train.target)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "'Jeff Bezos is the world’s richest person and can afford virtually any luxury' => soc.religion.christian\n",
      "'computer gpu is fast' => comp.graphics\n"
     ]
    }
   ],
   "source": [
    "docs_new = [\n",
    "    'Jeff Bezos is the world’s richest person and can afford virtually any luxury',\n",
    "    'computer gpu is fast',\n",
    "]\n",
    "\n",
    "# Reuse the already-fitted vectorizer and TF-IDF transformer: transform only, no refit.\n",
    "X_new_counts = count_vect.transform(docs_new)\n",
    "X_new_tfidf = tfidf_transformer.transform(X_new_counts)\n",
    "predicted = clf.predict(X_new_tfidf)\n",
    "\n",
    "# Print each new document next to the name of its predicted category.\n",
    "for idx, text in enumerate(docs_new):\n",
    "    print('%r => %s' % (text, twenty_train.target_names[predicted[idx]]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.pipeline import Pipeline\n",
    "\n",
    "# Chain vectorizer, TF-IDF transformer and classifier so they are fitted and applied as one estimator.\n",
    "text_clf = Pipeline(steps=[\n",
    "    ('vect', CountVectorizer()),\n",
    "    ('tfidf', TfidfTransformer()),\n",
    "    ('clf', MultinomialNB()),\n",
    "])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Pipeline(memory=None,\n",
       "     steps=[('vect', CountVectorizer(analyzer='word', binary=False, decode_error='strict',\n",
       "        dtype=<class 'numpy.int64'>, encoding='utf-8', input='content',\n",
       "        lowercase=True, max_df=1.0, max_features=None, min_df=1,\n",
       "        ngram_range=(1, 1), preprocessor=None, stop_words=None,\n",
       "        strip...inear_tf=False, use_idf=True)), ('clf', MultinomialNB(alpha=1.0, class_prior=None, fit_prior=True))])"
      ]
     },
     "execution_count": 38,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Fit the whole pipeline directly on the raw training documents and their targets.\n",
    "text_clf.fit(twenty_train.data, twenty_train.target)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.8348868175765646"
      ]
     },
     "execution_count": 39,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import numpy as np\n",
    "\n",
    "# Test split, loaded with the same categories and seed as the training split.\n",
    "twenty_test = fetch_20newsgroups(\n",
    "    subset='test',\n",
    "    categories=categories,\n",
    "    shuffle=True,\n",
    "    random_state=42,\n",
    ")\n",
    "docs_test = twenty_test.data\n",
    "predicted = text_clf.predict(docs_test)\n",
    "\n",
    "# Fraction of test documents whose predicted label equals the true label.\n",
    "np.mean(predicted == twenty_test.target)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Pipeline(memory=None,\n",
       "     steps=[('vect', CountVectorizer(analyzer='word', binary=False, decode_error='strict',\n",
       "        dtype=<class 'numpy.int64'>, encoding='utf-8', input='content',\n",
       "        lowercase=True, max_df=1.0, max_features=None, min_df=1,\n",
       "        ngram_range=(1, 1), preprocessor=None, stop_words=None,\n",
       "        strip...ty='l2', power_t=0.5, random_state=42, shuffle=True,\n",
       "       tol=None, verbose=0, warm_start=False))])"
      ]
     },
     "execution_count": 44,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from sklearn.linear_model import SGDClassifier\n",
    "\n",
    "# NOTE: this rebinds text_clf; the cells below evaluate and tune this SGD-based pipeline.\n",
    "text_clf = Pipeline(steps=[\n",
    "    ('vect', CountVectorizer()),\n",
    "    ('tfidf', TfidfTransformer()),\n",
    "    ('clf', SGDClassifier(\n",
    "        loss='hinge',\n",
    "        penalty='l2',\n",
    "        alpha=1e-3,\n",
    "        random_state=42,\n",
    "        max_iter=5,\n",
    "        tol=None,\n",
    "    )),\n",
    "])\n",
    "text_clf.fit(twenty_train.data, twenty_train.target)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 46,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.9127829560585885"
      ]
     },
     "execution_count": 46,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Re-evaluate on the test documents using whichever pipeline text_clf currently refers to.\n",
    "predicted = text_clf.predict(docs_test)\n",
    "np.mean(predicted == twenty_test.target)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 48,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "                        precision    recall  f1-score   support\n",
      "\n",
      "           alt.atheism       0.95      0.81      0.87       319\n",
      "         comp.graphics       0.88      0.97      0.92       389\n",
      "               sci.med       0.94      0.90      0.92       396\n",
      "soc.religion.christian       0.90      0.95      0.93       398\n",
      "\n",
      "           avg / total       0.92      0.91      0.91      1502\n",
      "\n"
     ]
    }
   ],
   "source": [
    "from sklearn import metrics\n",
    "\n",
    "# Per-category report for the most recent `predicted` values on the test split.\n",
    "print(metrics.classification_report(\n",
    "    twenty_test.target,\n",
    "    predicted,\n",
    "    target_names=twenty_test.target_names,\n",
    "))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 49,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[258,  11,  15,  35],\n",
       "       [  4, 379,   3,   3],\n",
       "       [  5,  33, 355,   3],\n",
       "       [  5,  10,   4, 379]])"
      ]
     },
     "execution_count": 49,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Confusion matrix of the true test targets (first argument) against the predictions (second argument).\n",
    "metrics.confusion_matrix(twenty_test.target, predicted)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 50,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.model_selection import GridSearchCV\n",
    "# Candidate values to search; each key is '<pipeline step name>__<parameter name>'.\n",
    "parameters = {'vect__ngram_range': [(1, 1), (1, 2)],\n",
    "              'tfidf__use_idf': (True, False),\n",
    "              'clf__alpha': (1e-2, 1e-3),\n",
    "             }"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 51,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Grid search over `parameters` on the current text_clf pipeline, using all cores (n_jobs=-1).\n",
    "# cv is pinned explicitly: scikit-learn's default fold count changed from 3 to 5 in 0.22,\n",
    "# so relying on the default makes best_score_ / best_params_ depend on the installed version.\n",
    "gs_clf = GridSearchCV(text_clf, parameters, cv=3, n_jobs=-1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 52,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run the search on only the first 400 training documents.\n",
    "gs_clf = gs_clf.fit(twenty_train.data[:400], twenty_train.target[:400])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 53,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'soc.religion.christian'"
      ]
     },
     "execution_count": 53,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Predict with the fitted search object and translate the label index into a category name.\n",
    "twenty_train.target_names[gs_clf.predict(['God is love'])[0]]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 54,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.9"
      ]
     },
     "execution_count": 54,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Best score recorded by the grid search.\n",
    "gs_clf.best_score_"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 55,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "clf__alpha: 0.001\n",
      "tfidf__use_idf: True\n",
      "vect__ngram_range: (1, 1)\n"
     ]
    }
   ],
   "source": [
    "# Report the selected value for every searched parameter, in sorted key order.\n",
    "for key in sorted(parameters):\n",
    "    print(\"%s: %r\" % (key, gs_clf.best_params_[key]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 58,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([2, 2, 2])"
      ]
     },
     "execution_count": 58,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Integer targets of the first three test documents.\n",
    "twenty_test.target[:3]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
